"""
Author : Mr.Sun
Datetime : 2024/12/6 17:40 
FileName : ski_learn_key.py
Desc : TF-IDF semantic text classifier (jieba + scikit-learn) with a keyword-rule fast path.
"""
import jieba
from typing import List, Dict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder


class SemanticClassifier:
    """TF-IDF based semantic classifier with a rule-based fast path.

    One training document is built per category (category name plus its
    keywords). New text is tokenized with jieba, TF-IDF vectorized, and
    scored by cosine similarity against the category documents. A simple
    keyword-containment rule matcher is provided as a cheap alternative.
    """

    def __init__(self, categories: Dict[str, List[str]] = None):
        """
        Initialize the semantic classifier.

        :param categories: mapping of category name -> keyword list;
            falls back to ``self.default_categories`` when omitted.
        """
        # Default category definitions. Keyword lists are deduplicated
        # (the original data contained repeated entries such as "大学").
        self.default_categories = {
            "校园招聘": ["学院", "大学", "应届生", "研究生", "专业学校", "校园招聘", "学校", "博士", "教师"],
            "地方": ["省", "市", "区", "县", "乡", "镇"],
            "政府": ["政府", "机关", "事业单位", "公务员", "编制"]
        }

        # Caller-supplied categories replace the defaults entirely.
        self.categories = categories or self.default_categories

        # Label encoder kept for API compatibility (fitted but not used
        # by the methods below).
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(list(self.categories.keys()))

        # Parallel label list lets similarity scores be mapped back to a
        # category without re-parsing the training document text (which
        # would break for category names containing whitespace).
        self.train_labels = list(self.categories.keys())
        self.train_corpus = self._build_train_corpus()

        # Fit the TF-IDF vectorizer once at construction time.
        self.vectorizer = TfidfVectorizer()
        self.corpus_vectors = self.vectorizer.fit_transform(self.train_corpus)

    def _build_train_corpus(self) -> List[str]:
        """Build one training document per category.

        Each document is the category name followed by its keywords,
        space-joined so TfidfVectorizer's default analyzer tokenizes it.
        (Previously each document was duplicated 5 times; identical
        copies add nothing under the per-category max aggregation.)
        """
        return [
            f"{category} {' '.join(keywords)}"
            for category, keywords in self.categories.items()
        ]

    def semantic_classify(self, text: str, top_k: int = 3) -> List[Dict]:
        """
        Classify text by TF-IDF cosine similarity against each category.

        :param text: text to classify
        :param top_k: maximum number of categories to return
        :return: list of ``{"category", "confidence"}`` dicts, highest
            confidence first; entries scoring <= 0.1 are dropped.
        """
        # Tokenize with jieba and re-join with spaces so the vectorizer's
        # whitespace analyzer sees individual Chinese words.
        words = jieba.lcut(text)
        text_vector = self.vectorizer.transform([' '.join(words)])

        # Cosine similarity of the query against every category document.
        similarities = cosine_similarity(text_vector, self.corpus_vectors)[0]

        # Keep the best score observed per category.
        category_scores: Dict[str, float] = {}
        for label, score in zip(self.train_labels, similarities):
            category_scores[label] = max(category_scores.get(label, 0.0), score)

        # Rank by confidence, highest first.
        ranked = sorted(
            category_scores.items(),
            key=lambda item: item[1],
            reverse=True
        )

        # Apply the confidence threshold; cast numpy scalars to plain
        # floats so the result serializes cleanly (e.g. to JSON).
        return [
            {"category": category, "confidence": round(float(confidence), 2)}
            for category, confidence in ranked[:top_k]
            if confidence > 0.1  # confidence threshold
        ]

    def match_category_by_rules(self, text: str) -> str:
        """
        Fast rule-based match: first category whose keyword occurs in text.

        :param text: text to match
        :return: the matched category name, or "未知" when none matches.
        """
        for category, keywords in self.categories.items():
            if any(keyword in text for keyword in keywords):
                return category
        return "未知"

    def advanced_classify(self, text: str) -> Dict:
        """
        Combined classification: rule match plus semantic ranking.

        :param text: text to classify
        :return: dict with "rule_match" (str) and "semantic_match"
            (list of scored categories from :meth:`semantic_classify`).
        """
        # Cheap keyword rule first, then the TF-IDF similarity ranking.
        rule_match = self.match_category_by_rules(text)
        semantic_results = self.semantic_classify(text)

        return {
            "rule_match": rule_match,
            "semantic_match": semantic_results
        }


if __name__ == "__main__":
    classifier = SemanticClassifier()
    result = classifier.advanced_classify("2024年四川成都中医药大学药学院/现代中药产业学院招聘科研助理3人公告")
    print(f"规则匹配: {result['rule_match']}")
    print(f"语义匹配: {result['semantic_match']}\n")
